{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp vision.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.torch_basics import *\n",
    "from fastai.data.all import *\n",
    "from fastai.vision.core import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *\n",
    "# from fastai.vision.augment import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vision data\n",
    "\n",
 Helper">
    "> Helper functions to get data in a `DataLoaders` in the vision application and the higher-level class `ImageDataLoaders`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main classes defined in this module are `ImageDataLoaders` and `SegmentationDataLoaders`, so you probably want to jump to their definitions. They provide factory methods that are a great way to quickly get your data ready for training, see the [vision tutorial](http://docs.fast.ai/tutorial.vision) for examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(subplots)\n",
    "def get_grid(n, nrows=None, ncols=None, add_vert=0, figsize=None, double=False, title=None, return_fig=False, **kwargs):\n",
    "    \"Return a grid of `n` axes, `rows` by `cols`\"\n",
    "    nrows = nrows or int(math.sqrt(n))\n",
    "    ncols = ncols or int(np.ceil(n/nrows))\n",
    "    if double: ncols*=2 ; n*=2\n",
    "    fig,axs = subplots(nrows, ncols, figsize=figsize, **kwargs)\n",
    "    axs = [ax if i<n else ax.set_axis_off() for i, ax in enumerate(axs.flatten())][:n]\n",
    "    if title is not None: fig.suptitle(title, weight='bold', size=14)\n",
    "    return (fig,axs) if return_fig else axs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is used by the type-dispatched versions of `show_batch` and `show_results` for the vision application. By default, there will be `int(math.sqrt(n))` rows and `ceil(n/rows)` columns. `double` will double the number of columns and `n`. The default `figsize` is `(cols*imsize, rows*imsize+add_vert)`. If a `title` is passed it is set to the figure. `sharex`, `sharey`, `squeeze`, `subplot_kw` and `gridspec_kw` are all passed down to `plt.subplots`. If `return_fig` is `True`, returns `fig,axs`, otherwise just `axs`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def clip_remove_empty(bbox, label):\n",
    "    \"Clip bounding boxes with image border and label background the empty ones\"\n",
    "    bbox = torch.clamp(bbox, -1, 1)\n",
    "    empty = ((bbox[...,2] - bbox[...,0])*(bbox[...,3] - bbox[...,1]) <= 0.)\n",
    "    return (bbox[~empty], label[~empty])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bb = tensor([[-2,-0.5,0.5,1.5], [-0.5,-0.5,0.5,0.5], [1,0.5,0.5,0.75], [-0.5,-0.5,0.5,0.5], [-2, -0.5, -1.5, 0.5]])\n",
    "bb,lbl = clip_remove_empty(bb, tensor([1,2,3,2,5]))\n",
    "test_eq(bb, tensor([[-1,-0.5,0.5,1.], [-0.5,-0.5,0.5,0.5], [-0.5,-0.5,0.5,0.5]]))\n",
    "test_eq(lbl, tensor([1,2,2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def bb_pad(samples, pad_idx=0):\n",
    "    \"Function that collect `samples` of labelled bboxes and adds padding with `pad_idx`.\"\n",
    "    samples = [(s[0], *clip_remove_empty(*s[1:])) for s in samples]\n",
    "    max_len = max([len(s[2]) for s in samples])\n",
    "    def _f(img,bbox,lbl):\n",
    "        bbox = torch.cat([bbox,bbox.new_zeros(max_len-bbox.shape[0], 4)])\n",
    "        lbl  = torch.cat([lbl, lbl .new_zeros(max_len-lbl .shape[0])+pad_idx])\n",
    "        return img,bbox,lbl\n",
    "    return [_f(*s) for s in samples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img1,img2 = TensorImage(torch.randn(16,16,3)),TensorImage(torch.randn(16,16,3))\n",
    "bb1 = tensor([[-2,-0.5,0.5,1.5], [-0.5,-0.5,0.5,0.5], [1,0.5,0.5,0.75], [-0.5,-0.5,0.5,0.5]])\n",
    "lbl1 = tensor([1, 2, 3, 2])\n",
    "bb2 = tensor([[-0.5,-0.5,0.5,0.5], [-0.5,-0.5,0.5,0.5]])\n",
    "lbl2 = tensor([2, 2])\n",
    "samples = [(img1, bb1, lbl1), (img2, bb2, lbl2)]\n",
    "res = bb_pad(samples)\n",
    "non_empty = tensor([True,True,False,True])\n",
    "test_eq(res[0][0], img1)\n",
    "test_eq(res[0][1], tensor([[-1,-0.5,0.5,1.], [-0.5,-0.5,0.5,0.5], [-0.5,-0.5,0.5,0.5]]))\n",
    "test_eq(res[0][2], tensor([1,2,2]))\n",
    "test_eq(res[1][0], img2)\n",
    "test_eq(res[1][1], tensor([[-0.5,-0.5,0.5,0.5], [-0.5,-0.5,0.5,0.5], [0,0,0,0]]))\n",
    "test_eq(res[1][2], tensor([2,2,0]))      "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show methods -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_batch(x:TensorImage, y, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)\n",
    "    ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_batch(x:TensorImage, y:TensorImage, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, add_vert=1, figsize=figsize, double=True)\n",
    "    for i in range(2):\n",
    "        ctxs[i::2] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs[i::2],range(max_n))]\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `TransformBlock`s for vision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These are the blocks the vision application provide for the [data block API](http://docs.fast.ai/data.block)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def ImageBlock(cls=PILImage):\n",
    "    \"A `TransformBlock` for images of `cls`\"\n",
    "    return TransformBlock(type_tfms=cls.create, batch_tfms=IntToFloatTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def MaskBlock(codes=None):\n",
    "    \"A `TransformBlock` for segmentation masks, potentially with `codes`\"\n",
    "    return TransformBlock(type_tfms=PILMask.create, item_tfms=AddMaskCodes(codes=codes), batch_tfms=IntToFloatTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "PointBlock = TransformBlock(type_tfms=TensorPoint.create, item_tfms=PointScaler)\n",
    "BBoxBlock = TransformBlock(type_tfms=TensorBBox.create, item_tfms=PointScaler, dls_kwargs = {'before_batch': bb_pad})\n",
    "\n",
    "PointBlock.__doc__ = \"A `TransformBlock` for points in an image\"\n",
    "BBoxBlock.__doc__  = \"A `TransformBlock` for bounding boxes in an image\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"PointBlock\" class=\"doc_header\"><code>PointBlock</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "A [`TransformBlock`](/data.block.html#TransformBlock) for points in an image"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(PointBlock, name='PointBlock')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BBoxBlock\" class=\"doc_header\"><code>BBoxBlock</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "A [`TransformBlock`](/data.block.html#TransformBlock) for bounding boxes in an image"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BBoxBlock, name='BBoxBlock')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def BBoxLblBlock(vocab=None, add_na=True):\n",
    "    \"A `TransformBlock` for labeled bounding boxes, potentially with `vocab`\"\n",
    "    return TransformBlock(type_tfms=MultiCategorize(vocab=vocab, add_na=add_na), item_tfms=BBoxLabeler)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `add_na` is `True`, a new category is added for NaN (that will represent the background class)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ImageDataLoaders -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ImageDataLoaders(DataLoaders):\n",
    "    \"Basic wrapper around several `DataLoader`s with factory methods for computer vision problems\"\n",
    "    @classmethod\n",
    "    @delegates(DataLoaders.from_dblock)\n",
    "    def from_folder(cls, path, train='train', valid='valid', valid_pct=None, seed=None, vocab=None, item_tfms=None,\n",
    "                    batch_tfms=None, **kwargs):\n",
    "        \"Create from imagenet style dataset in `path` with `train` and `valid` subfolders (or provide `valid_pct`)\"\n",
    "        splitter = GrandparentSplitter(train_name=train, valid_name=valid) if valid_pct is None else RandomSplitter(valid_pct, seed=seed)\n",
    "        get_items = get_image_files if valid_pct else partial(get_image_files, folders=[train, valid])\n",
    "        dblock = DataBlock(blocks=(ImageBlock, CategoryBlock(vocab=vocab)),\n",
    "                           get_items=get_items,\n",
    "                           splitter=splitter,\n",
    "                           get_y=parent_label,\n",
    "                           item_tfms=item_tfms,\n",
    "                           batch_tfms=batch_tfms)\n",
    "        return cls.from_dblock(dblock, path, path=path, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    @delegates(DataLoaders.from_dblock)\n",
    "    def from_path_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None, batch_tfms=None, **kwargs):\n",
    "        \"Create from list of `fnames` in `path`s with `label_func`\"\n",
    "        dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),\n",
    "                           splitter=RandomSplitter(valid_pct, seed=seed),\n",
    "                           get_y=label_func,\n",
    "                           item_tfms=item_tfms,\n",
    "                           batch_tfms=batch_tfms)\n",
    "        return cls.from_dblock(dblock, fnames, path=path, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def from_name_func(cls, path, fnames, label_func, **kwargs):\n",
    "        \"Create from the name attrs of `fnames` in `path`s with `label_func`\"\n",
    "        f = using_attr(label_func, 'name')\n",
    "        return cls.from_path_func(path, fnames, f, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def from_path_re(cls, path, fnames, pat, **kwargs):\n",
    "        \"Create from list of `fnames` in `path`s with re expression `pat`\"\n",
    "        return cls.from_path_func(path, fnames, RegexLabeller(pat), **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    @delegates(DataLoaders.from_dblock)\n",
    "    def from_name_re(cls, path, fnames, pat, **kwargs):\n",
    "        \"Create from the name attrs of `fnames` in `path`s with re expression `pat`\"\n",
    "        return cls.from_name_func(path, fnames, RegexLabeller(pat), **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    @delegates(DataLoaders.from_dblock)\n",
    "    def from_df(cls, df, path='.', valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='', label_col=1, label_delim=None,\n",
    "                y_block=None, valid_col=None, item_tfms=None, batch_tfms=None, **kwargs):\n",
    "        \"Create from `df` using `fn_col` and `label_col`\"\n",
    "        pref = f'{Path(path) if folder is None else Path(path)/folder}{os.path.sep}'\n",
    "        if y_block is None:\n",
    "            is_multi = (is_listy(label_col) and len(label_col) > 1) or label_delim is not None\n",
    "            y_block = MultiCategoryBlock if is_multi else CategoryBlock\n",
    "        splitter = RandomSplitter(valid_pct, seed=seed) if valid_col is None else ColSplitter(valid_col)\n",
    "        dblock = DataBlock(blocks=(ImageBlock, y_block),\n",
    "                           get_x=ColReader(fn_col, pref=pref, suff=suff),\n",
    "                           get_y=ColReader(label_col, label_delim=label_delim),\n",
    "                           splitter=splitter,\n",
    "                           item_tfms=item_tfms,\n",
    "                           batch_tfms=batch_tfms)\n",
    "        return cls.from_dblock(dblock, df, path=path, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def from_csv(cls, path, csv_fname='labels.csv', header='infer', delimiter=None, **kwargs):\n",
    "        \"Create from `path/csv_fname` using `fn_col` and `label_col`\"\n",
    "        df = pd.read_csv(Path(path)/csv_fname, header=header, delimiter=delimiter)\n",
    "        return cls.from_df(df, path=path, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    @delegates(DataLoaders.from_dblock)\n",
    "    def from_lists(cls, path, fnames, labels, valid_pct=0.2, seed:int=None, y_block=None, item_tfms=None, batch_tfms=None,\n",
    "                   **kwargs):\n",
    "        \"Create from list of `fnames` and `labels` in `path`\"\n",
    "        if y_block is None:\n",
    "            y_block = MultiCategoryBlock if is_listy(labels[0]) and len(labels[0]) > 1 else (\n",
    "                RegressionBlock if isinstance(labels[0], float) else CategoryBlock)\n",
    "        dblock = DataBlock.from_columns(blocks=(ImageBlock, y_block),\n",
    "                           splitter=RandomSplitter(valid_pct, seed=seed),\n",
    "                           item_tfms=item_tfms,\n",
    "                           batch_tfms=batch_tfms)\n",
    "        return cls.from_dblock(dblock, (fnames, labels), path=path, **kwargs)\n",
    "\n",
    "ImageDataLoaders.from_csv = delegates(to=ImageDataLoaders.from_df)(ImageDataLoaders.from_csv)\n",
    "ImageDataLoaders.from_name_func = delegates(to=ImageDataLoaders.from_path_func)(ImageDataLoaders.from_name_func)\n",
    "ImageDataLoaders.from_path_re = delegates(to=ImageDataLoaders.from_path_func)(ImageDataLoaders.from_path_re)\n",
    "ImageDataLoaders.from_name_re = delegates(to=ImageDataLoaders.from_name_func)(ImageDataLoaders.from_name_re)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This class should not be used directly, one of the factory methods should be preferred instead. All those factory methods accept as arguments:\n",
    "\n",
    "- `item_tfms`: one or several transforms applied to the items before batching them\n",
    "- `batch_tfms`: one or several transforms applied to the batches once they are formed\n",
    "- `bs`: the batch size\n",
    "- `val_bs`: the batch size for the validation `DataLoader` (defaults to `bs`)\n",
    "- `shuffle_train`: if we shuffle the training `DataLoader` or not\n",
    "- `device`: the PyTorch device to use (defaults to `default_device()`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"ImageDataLoaders.from_folder\" class=\"doc_header\"><code>ImageDataLoaders.from_folder</code><a href=\"__main__.py#L4\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>ImageDataLoaders.from_folder</code>(**`path`**, **`train`**=*`'train'`*, **`valid`**=*`'valid'`*, **`valid_pct`**=*`None`*, **`seed`**=*`None`*, **`vocab`**=*`None`*, **`item_tfms`**=*`None`*, **`batch_tfms`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create from imagenet style dataset in `path` with `train` and `valid` subfolders (or provide `valid_pct`)"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ImageDataLoaders.from_folder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `valid_pct` is provided, a random split is performed (with an optional `seed`) by setting aside that percentage of the data for the validation set (instead of looking at the grandparents folder). If a `vocab` is passed, only the folders with names in `vocab` are kept.\n",
    "\n",
    "Here is an example loading a subsample of MNIST:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.MNIST_TINY)\n",
    "dls = ImageDataLoaders.from_folder(path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Passing `valid_pct` will ignore the valid/train folders and do a new random split:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Path('/home/jhoward/.fastai/data/mnist_tiny/train/7/7795.png'),\n",
       " Path('/home/jhoward/.fastai/data/mnist_tiny/valid/7/7654.png'),\n",
       " Path('/home/jhoward/.fastai/data/mnist_tiny/train/3/8830.png')]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dls = ImageDataLoaders.from_folder(path, valid_pct=0.2)\n",
    "dls.valid_ds.items[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"ImageDataLoaders.from_path_func\" class=\"doc_header\"><code>ImageDataLoaders.from_path_func</code><a href=\"__main__.py#L19\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>ImageDataLoaders.from_path_func</code>(**`path`**, **`fnames`**, **`label_func`**, **`valid_pct`**=*`0.2`*, **`seed`**=*`None`*, **`item_tfms`**=*`None`*, **`batch_tfms`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create from list of `fnames` in `path`s with `label_func`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ImageDataLoaders.from_path_func)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The validation set is a random `subset` of `valid_pct`, optionally created with `seed` for reproducibility.\n",
    "\n",
    "Here is how to create the same `DataLoaders` on the MNIST dataset as the previous example with a `label_func`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fnames = get_image_files(path)\n",
    "def label_func(x): return x.parent.name\n",
    "dls = ImageDataLoaders.from_path_func(path, fnames, label_func)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is another example on the pets dataset. Here filenames are all in an \"images\" folder and their names have the form `class_name_123.jpg`. One way to properly label them is thus to throw away everything after the last `_`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"ImageDataLoaders.from_path_re\" class=\"doc_header\"><code>ImageDataLoaders.from_path_re</code><a href=\"__main__.py#L36\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>ImageDataLoaders.from_path_re</code>(**`path`**, **`fnames`**, **`pat`**, **`valid_pct`**=*`0.2`*, **`seed`**=*`None`*, **`item_tfms`**=*`None`*, **`batch_tfms`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create from list of `fnames` in `path`s with re expression `pat`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ImageDataLoaders.from_path_re)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The validation set is a random subset of `valid_pct`, optionally created with `seed` for reproducibility.\n",
    "\n",
    "Here is how to create the same `DataLoaders` on the MNIST dataset as the previous example (you will need to change the initial two / by a \\ on Windows):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pat = r'/([^/]*)/\\d+.png$'\n",
    "dls = ImageDataLoaders.from_path_re(path, fnames, pat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"ImageDataLoaders.from_name_func\" class=\"doc_header\"><code>ImageDataLoaders.from_name_func</code><a href=\"__main__.py#L30\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>ImageDataLoaders.from_name_func</code>(**`path`**, **`fnames`**, **`label_func`**, **`valid_pct`**=*`0.2`*, **`seed`**=*`None`*, **`item_tfms`**=*`None`*, **`batch_tfms`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create from the name attrs of `fnames` in `path`s with `label_func`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ImageDataLoaders.from_name_func)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The validation set is a random subset of `valid_pct`, optionally created with `seed` for reproducibility. This method does the same as `ImageDataLoaders.from_path_func` except `label_func` is applied to the name of each filenames, and not the full path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"ImageDataLoaders.from_name_re\" class=\"doc_header\"><code>ImageDataLoaders.from_name_re</code><a href=\"__main__.py#L41\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>ImageDataLoaders.from_name_re</code>(**`path`**, **`fnames`**, **`pat`**, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create from the name attrs of `fnames` in `path`s with re expression `pat`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ImageDataLoaders.from_name_re)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The validation set is a random subset of `valid_pct`, optionally created with `seed` for reproducibility. This method does the same as `ImageDataLoaders.from_path_re` except `pat` is applied to the name of each filenames, and not the full path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"ImageDataLoaders.from_df\" class=\"doc_header\"><code>ImageDataLoaders.from_df</code><a href=\"__main__.py#L47\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>ImageDataLoaders.from_df</code>(**`df`**, **`path`**=*`'.'`*, **`valid_pct`**=*`0.2`*, **`seed`**=*`None`*, **`fn_col`**=*`0`*, **`folder`**=*`None`*, **`suff`**=*`''`*, **`label_col`**=*`1`*, **`label_delim`**=*`None`*, **`y_block`**=*`None`*, **`valid_col`**=*`None`*, **`item_tfms`**=*`None`*, **`batch_tfms`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create from `df` using `fn_col` and `label_col`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ImageDataLoaders.from_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The validation set is a random subset of `valid_pct`, optionally created with `seed` for reproducibility. Alternatively, if your `df` contains a `valid_col`, give its name or its index to that argument (the column should have `True` for the elements going to the validation set). \n",
    "\n",
    "You can add an additional `folder` to the filenames in `df` if they should not be concatenated directly to `path`. If they do not contain the proper extensions, you can add `suff`. If your label column contains multiple labels on each row, you can use `label_delim` to warn the library you have a multi-label problem. \n",
    "\n",
    "`y_block` should be passed when the task automatically picked by the library is wrong, you should then give `CategoryBlock`, `MultiCategoryBlock` or `RegressionBlock`. For more advanced uses, you should use the data block API.\n",
    "\n",
    "The tiny mnist example from before also contains a version in a dataframe:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train/3/7463.png</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>train/3/9829.png</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>train/3/7881.png</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>train/3/8065.png</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>train/3/7046.png</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               name  label\n",
       "0  train/3/7463.png      3\n",
       "1  train/3/9829.png      3\n",
       "2  train/3/7881.png      3\n",
       "3  train/3/8065.png      3\n",
       "4  train/3/7046.png      3"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = untar_data(URLs.MNIST_TINY)\n",
    "df = pd.read_csv(path/'labels.csv')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is how to load it using `ImageDataLoaders.from_df`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = ImageDataLoaders.from_df(df, path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is another example with a multi-label problem:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fname</th>\n",
       "      <th>labels</th>\n",
       "      <th>is_valid</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>000005.jpg</td>\n",
       "      <td>chair</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>000007.jpg</td>\n",
       "      <td>car</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>000009.jpg</td>\n",
       "      <td>horse person</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>000012.jpg</td>\n",
       "      <td>car</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>000016.jpg</td>\n",
       "      <td>bicycle</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        fname        labels  is_valid\n",
       "0  000005.jpg         chair      True\n",
       "1  000007.jpg           car      True\n",
       "2  000009.jpg  horse person      True\n",
       "3  000012.jpg           car     False\n",
       "4  000016.jpg       bicycle      True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = untar_data(URLs.PASCAL_2007)\n",
    "df = pd.read_csv(path/'train.csv')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = ImageDataLoaders.from_df(df, path, folder='train', valid_col='is_valid')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that you can also pass `2` to `valid_col` (the index of the column, starting at 0)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"ImageDataLoaders.from_csv\" class=\"doc_header\"><code>ImageDataLoaders.from_csv</code><a href=\"__main__.py#L65\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>ImageDataLoaders.from_csv</code>(**`path`**, **`csv_fname`**=*`'labels.csv'`*, **`header`**=*`'infer'`*, **`delimiter`**=*`None`*, **`valid_pct`**=*`0.2`*, **`seed`**=*`None`*, **`fn_col`**=*`0`*, **`folder`**=*`None`*, **`suff`**=*`''`*, **`label_col`**=*`1`*, **`label_delim`**=*`None`*, **`y_block`**=*`None`*, **`valid_col`**=*`None`*, **`item_tfms`**=*`None`*, **`batch_tfms`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create from `path/csv_fname` using `fn_col` and `label_col`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ImageDataLoaders.from_csv)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Same as `ImageDataLoaders.from_df` after loading the file with `header` and `delimiter`.\n",
    "\n",
    "Here is how to load the same dataset as before with this method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = ImageDataLoaders.from_csv(path, 'train.csv', folder='train', valid_col='is_valid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"ImageDataLoaders.from_lists\" class=\"doc_header\"><code>ImageDataLoaders.from_lists</code><a href=\"__main__.py#L71\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>ImageDataLoaders.from_lists</code>(**`path`**, **`fnames`**, **`labels`**, **`valid_pct`**=*`0.2`*, **`seed`**:`int`=*`None`*, **`y_block`**=*`None`*, **`item_tfms`**=*`None`*, **`batch_tfms`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create from list of `fnames` and `labels` in `path`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(ImageDataLoaders.from_lists)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The validation set is a random subset containing a fraction `valid_pct` of the items, optionally created with `seed` for reproducibility. `y_block` can be passed to specify the type of the targets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.PETS)\n",
    "fnames = get_image_files(path/\"images\")\n",
    "labels = ['_'.join(x.name.split('_')[:-1]) for x in fnames]\n",
    "dls = ImageDataLoaders.from_lists(path, fnames, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class SegmentationDataLoaders(DataLoaders):\n",
    "    \"Basic wrapper around several `DataLoader`s with factory methods for segmentation problems\"\n",
    "    @classmethod\n",
    "    @delegates(DataLoaders.from_dblock)\n",
    "    def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, codes=None, item_tfms=None, batch_tfms=None, **kwargs):\n",
    "        \"Create from list of `fnames` in `path` with `label_func`.\"\n",
    "        # Inputs are images; targets are masks, with `codes` handed to `MaskBlock`\n",
    "        blocks = (ImageBlock, MaskBlock(codes=codes))\n",
    "        # Random train/valid split driven by `valid_pct`; `seed` is forwarded to the splitter\n",
    "        splitter = RandomSplitter(valid_pct, seed=seed)\n",
    "        dblock = DataBlock(blocks=blocks, splitter=splitter, get_y=label_func,\n",
    "                           item_tfms=item_tfms, batch_tfms=batch_tfms)\n",
    "        # Any remaining `kwargs` are passed straight through to `DataLoaders.from_dblock`\n",
    "        return cls.from_dblock(dblock, fnames, path=path, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"SegmentationDataLoaders.from_label_func\" class=\"doc_header\"><code>SegmentationDataLoaders.from_label_func</code><a href=\"__main__.py#L4\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>SegmentationDataLoaders.from_label_func</code>(**`path`**, **`fnames`**, **`label_func`**, **`valid_pct`**=*`0.2`*, **`seed`**=*`None`*, **`codes`**=*`None`*, **`item_tfms`**=*`None`*, **`batch_tfms`**=*`None`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n",
       "\n",
       "Create from list of `fnames` in `path`s with `label_func`."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(SegmentationDataLoaders.from_label_func)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The validation set is a random subset containing a fraction `valid_pct` of the items, optionally created with `seed` for reproducibility. `codes` contains the mapping from index to label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.CAMVID_TINY)\n",
    "fnames = get_image_files(path/'images')\n",
    "codes = np.loadtxt(path/'codes.txt', dtype=str)\n",
    "\n",
    "def label_func(x):\n",
    "    \"Return the path under `path/'labels'` built from the stem and suffix of `x`\"\n",
    "    return path/'labels'/f'{x.stem}_P{x.suffix}'\n",
    "\n",
    "dls = SegmentationDataLoaders.from_label_func(path, fnames, label_func, codes=codes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
